// Copyright 2013 <chaishushan{AT}gmail.com>. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package main

import (
	"bytes"
	"os"
	"strings"
	"text/template"

	descriptor "code.google.com/p/goprotobuf/protoc-gen-go/descriptor"
	generator "code.google.com/p/goprotobuf/protoc-gen-go/generator"
)

// Names of the boolean switches read by this plugin. Each one may be given as
// a protoc plugin parameter (e.g. protoc --go_out=go_generic_services=true:.)
// or as an environment variable spelled in upper or lower case.
const (
	// go_generic_services turns generation of the service interface, server
	// helpers and client stubs on or off.
	go_generic_services = "go_generic_services"

	// go_generic_services_use_pkg_name, when true, prefixes the registered
	// service name with the proto package name.
	go_generic_services_use_pkg_name = "go_generic_services_use_pkg_name"
)

// servicePlugin is a protoc-gen-go plugin that emits net/rpc service code
// (interface, server helpers and client stub) wired to protorpc codecs.
// It embeds the active *generator.Generator so generator helpers such as P,
// TypeName and ObjectNamed can be called directly on the plugin.
type servicePlugin struct {
	*generator.Generator
}

// Name returns the plugin name, "protorpc".
func (p *servicePlugin) Name() string {
	const pluginName = "protorpc"
	return pluginName
}

// Init remembers g for the later generation callbacks. It is called once
// after data structures are built but before code generation begins.
func (p *servicePlugin) Init(g *generator.Generator) {
	p.Generator = g
}

// GenerateImports emits the import declarations needed by the code that
// Generate writes for file. Nothing is emitted when file declares no services
// or when generic services are disabled for file.
func (p *servicePlugin) GenerateImports(file *generator.FileDescriptor) {
	// Check the cheap condition first; the option lookup reads parameters
	// and environment variables.
	if len(file.Service) == 0 || !p.getGenericServicesOptions(file) {
		return
	}
	p.P(`import "io"`)
	p.P(`import "log"`)
	p.P(`import "net"`)
	p.P(`import "net/rpc"`)
	p.P(`import "time"`)
	p.P(`import protorpc "code.google.com/p/protorpc"`)
}

// Generate writes, for every service declared in file, the service interface,
// the server helpers and the client stub, in that order. It writes nothing
// when generic services are disabled for file.
//
// NOTE(review): the original author warned "rpc service can't handle other
// proto message!!!" — the exact limitation is not evident here; confirm.
func (p *servicePlugin) Generate(file *generator.FileDescriptor) {
	if enabled := p.getGenericServicesOptions(file); enabled {
		for i := range file.Service {
			svc := file.Service[i]
			p.genServiceInterface(file, svc)
			p.genServiceServer(file, svc)
			p.genServiceClient(file, svc)
		}
	}
}

// getGenericServicesOptions reports whether service code should be generated
// for file. The go_generic_services switch is looked up, in order, in the
// protoc plugin parameters
// (protoc --go_out=go_generic_services=true:. xxx.proto), the
// GO_GENERIC_SERVICES environment variable and the go_generic_services
// environment variable; the first source holding a recognised boolean wins.
// When none does, the file's cc_generic_services option decides.
func (p *servicePlugin) getGenericServicesOptions(
	file *generator.FileDescriptor,
) bool {
	if value, found := p.lookupBoolOption(go_generic_services); found {
		return value
	}
	return file.GetOptions().GetCcGenericServices()
}

// getGenericServicesOptionsUsePkgName reports whether the proto package name
// should prefix the registered service name. The
// go_generic_services_use_pkg_name switch is looked up in the same sources and
// order as in getGenericServicesOptions
// (e.g. protoc --go_out=go_generic_services_use_pkg_name=true:. xxx.proto);
// it defaults to false when no source holds a recognised boolean.
func (p *servicePlugin) getGenericServicesOptionsUsePkgName(
	file *generator.FileDescriptor,
) bool {
	value, _ := p.lookupBoolOption(go_generic_services_use_pkg_name)
	return value
}

// lookupBoolOption resolves the boolean switch name from, in priority order,
// the protoc plugin parameters, the upper-cased environment variable and the
// lower-cased environment variable. A source whose value parseBoolOption does
// not recognise is skipped, not treated as an error. found is false when no
// source yields a recognised value.
func (p *servicePlugin) lookupBoolOption(name string) (value, found bool) {
	if raw, ok := p.Generator.Param[name]; ok {
		if value, found = parseBoolOption(raw); found {
			return value, true
		}
	}
	for _, key := range []string{strings.ToUpper(name), strings.ToLower(name)} {
		if value, found = parseBoolOption(os.Getenv(key)); found {
			return value, true
		}
	}
	return false, false
}

// parseBoolOption interprets s as a switch value. "1", or a value that
// lower-cases to "true", gives (true, true). "0", or a value that lower-cases
// to "false", gives (false, true). Anything else, including "", gives
// (false, false).
func parseBoolOption(s string) (value, ok bool) {
	switch lower := strings.ToLower(s); {
	case s == "1" || lower == "true":
		return true, true
	case s == "0" || lower == "false":
		return false, true
	}
	return false, false
}

// genServiceInterface emits the Go interface declaration for svc: one method
// per RPC taking a pointer to the request message and a pointer to the reply
// message and returning error. It also records the request and reply message
// types as used so the generator can account for them.
//
// Template execution failures abort generation with a panic instead of
// silently emitting truncated code.
func (p *servicePlugin) genServiceInterface(
	file *generator.FileDescriptor,
	svc *descriptor.ServiceDescriptorProto,
) {
	const serviceInterfaceTmpl = `
type {{.ServiceName}} interface {
	{{.CallMethodList}}
}
`
	const callMethodTmpl = `
{{.MethodName}}(in *{{.ArgsType}}, out *{{.ReplyType}}) error`

	serviceName := generator.CamelCase(svc.GetName())

	// gen call method list; the method template is the same for every RPC,
	// so parse it once.
	methodTmpl := template.Must(template.New("").Parse(callMethodTmpl))
	var methodList bytes.Buffer
	for _, m := range svc.Method {
		err := methodTmpl.Execute(&methodList, &struct{ MethodName, ArgsType, ReplyType string }{
			MethodName: generator.CamelCase(m.GetName()),
			ArgsType:   p.TypeName(p.ObjectNamed(m.GetInputType())),
			ReplyType:  p.TypeName(p.ObjectNamed(m.GetOutputType())),
		})
		if err != nil {
			panic("protorpc: executing interface method template: " + err.Error())
		}

		p.RecordTypeUse(m.GetInputType())
		p.RecordTypeUse(m.GetOutputType())
	}

	// gen all interface code
	var out bytes.Buffer
	t := template.Must(template.New("").Parse(serviceInterfaceTmpl))
	if err := t.Execute(&out, &struct{ ServiceName, CallMethodList string }{
		ServiceName:    serviceName,
		CallMethodList: methodList.String(),
	}); err != nil {
		panic("protorpc: executing interface template: " + err.Error())
	}
	p.P(out.String())
}

// genServiceServer emits the server-side helpers for svc:
// Accept<Svc>Client, Register<Svc>, New<Svc>Server and ListenAndServe<Svc>.
// Each registers the implementation with a net/rpc server under the name
// returned by makeServiceRegisterName and serves connections with protorpc
// server codecs.
//
// The generated ListenAndServe<Svc> returns accept errors to its caller
// (matching its error-returning signature) rather than terminating the
// process with log.Fatalf, so its deferred lis.Close() also runs.
func (p *servicePlugin) genServiceServer(
	file *generator.FileDescriptor,
	svc *descriptor.ServiceDescriptorProto,
) {
	const serviceHelperFunTmpl = `
// Accept{{.ServiceName}}Client accepts connections on the listener and serves requests
// for each incoming connection.  Accept blocks; the caller typically
// invokes it in a go statement.
func Accept{{.ServiceName}}Client(lis net.Listener, x {{.ServiceName}}) {
	srv := rpc.NewServer()
	if err := srv.RegisterName("{{.ServiceRegisterName}}", x); err != nil {
		log.Fatal(err)
	}

	for {
		conn, err := lis.Accept()
		if err != nil {
			log.Fatalf("lis.Accept(): %v\n", err)
		}
		go srv.ServeCodec(protorpc.NewServerCodec(conn))
	}
}

// Register{{.ServiceName}} publish the given {{.ServiceName}} implementation on the server.
func Register{{.ServiceName}}(srv *rpc.Server, x {{.ServiceName}}) error {
	if err := srv.RegisterName("{{.ServiceRegisterName}}", x); err != nil {
		return err
	}
	return nil
}

// New{{.ServiceName}}Server returns a new {{.ServiceName}} Server.
func New{{.ServiceName}}Server(x {{.ServiceName}}) *rpc.Server {
	srv := rpc.NewServer()
	if err := srv.RegisterName("{{.ServiceRegisterName}}", x); err != nil {
		log.Fatal(err)
	}
	return srv
}

// ListenAndServe{{.ServiceName}} listens on the network address addr and serves
// the given {{.ServiceName}} implementation. It blocks, and returns the first
// error encountered while listening, registering or accepting connections.
func ListenAndServe{{.ServiceName}}(network, addr string, x {{.ServiceName}}) error {
	lis, err := net.Listen(network, addr)
	if err != nil {
		return err
	}
	defer lis.Close()

	srv := rpc.NewServer()
	if err := srv.RegisterName("{{.ServiceRegisterName}}", x); err != nil {
		return err
	}

	for {
		conn, err := lis.Accept()
		if err != nil {
			return err
		}
		go srv.ServeCodec(protorpc.NewServerCodec(conn))
	}
}
`
	serviceName := generator.CamelCase(svc.GetName())

	var out bytes.Buffer
	t := template.Must(template.New("").Parse(serviceHelperFunTmpl))
	if err := t.Execute(&out, &struct{ ServiceName, ServiceRegisterName string }{
		ServiceName:         serviceName,
		ServiceRegisterName: p.makeServiceRegisterName(file, file.GetPackage(), serviceName),
	}); err != nil {
		panic("protorpc: executing server template: " + err.Error())
	}
	p.P(out.String())
}

// genServiceClient emits the client-side code for svc: the <Svc>Client stub
// type embedding *rpc.Client, one method per RPC forwarding to Call with
// "<register name>.<Method>", and the New<Svc>Client, Dial<Svc> and
// Dial<Svc>Timeout constructors, all backed by protorpc client codecs.
//
// Template execution failures abort generation with a panic instead of
// silently emitting truncated code.
func (p *servicePlugin) genServiceClient(
	file *generator.FileDescriptor,
	svc *descriptor.ServiceDescriptorProto,
) {
	const clientHelperFuncTmpl = `
type {{.ServiceName}}Client struct {
	*rpc.Client
}

// New{{.ServiceName}}Client returns a {{.ServiceName}} rpc.Client and stub to handle
// requests to the set of {{.ServiceName}} at the other end of the connection.
func New{{.ServiceName}}Client(conn io.ReadWriteCloser) (*{{.ServiceName}}Client, *rpc.Client) {
	c := rpc.NewClientWithCodec(protorpc.NewClientCodec(conn))
	return &{{.ServiceName}}Client{c}, c
}

{{.MethodList}}

// Dial{{.ServiceName}} connects to an {{.ServiceName}} at the specified network address.
func Dial{{.ServiceName}}(network, addr string) (*{{.ServiceName}}Client, *rpc.Client, error) {
	c, err := protorpc.Dial(network, addr)
	if err != nil {
		return nil, nil, err
	}
	return &{{.ServiceName}}Client{c}, c, nil
}

// Dial{{.ServiceName}}Timeout connects to an {{.ServiceName}} at the specified network address.
func Dial{{.ServiceName}}Timeout(network, addr string,
	timeout time.Duration) (*{{.ServiceName}}Client, *rpc.Client, error) {
	c, err := protorpc.DialTimeout(network, addr, timeout)
	if err != nil {
		return nil, nil, err
	}
	return &{{.ServiceName}}Client{c}, c, nil
}
`
	const clientMethodTmpl = `
func (c *{{.ServiceName}}Client) {{.MethodName}}(in *{{.ArgsType}}, out *{{.ReplyType}}) error {
	return c.Call("{{.ServiceRegisterName}}.{{.MethodName}}", in, out)
}`

	serviceName := generator.CamelCase(svc.GetName())
	// The register name depends only on the file and the service, so resolve
	// it (and the option lookups behind it) once rather than per method.
	registerName := p.makeServiceRegisterName(file, file.GetPackage(), serviceName)

	// gen client method list; the method template is parsed once and reused.
	methodTmpl := template.Must(template.New("").Parse(clientMethodTmpl))
	var methodList bytes.Buffer
	for _, m := range svc.Method {
		err := methodTmpl.Execute(&methodList, &struct{ ServiceName, ServiceRegisterName, MethodName, ArgsType, ReplyType string }{
			ServiceName:         serviceName,
			ServiceRegisterName: registerName,
			MethodName:          generator.CamelCase(m.GetName()),
			ArgsType:            p.TypeName(p.ObjectNamed(m.GetInputType())),
			ReplyType:           p.TypeName(p.ObjectNamed(m.GetOutputType())),
		})
		if err != nil {
			panic("protorpc: executing client method template: " + err.Error())
		}
	}

	// gen all client code
	var out bytes.Buffer
	t := template.Must(template.New("").Parse(clientHelperFuncTmpl))
	if err := t.Execute(&out, &struct{ ServiceName, MethodList string }{
		ServiceName: serviceName,
		MethodList:  methodList.String(),
	}); err != nil {
		panic("protorpc: executing client template: " + err.Error())
	}
	p.P(out.String())
}

// makeServiceRegisterName returns the name under which a service is
// registered with net/rpc, which generated clients also use as the prefix of
// their Call method strings. The result is "packageName.serviceName" when the
// go_generic_services_use_pkg_name switch is on, otherwise just serviceName.
func (p *servicePlugin) makeServiceRegisterName(
	file *generator.FileDescriptor,
	packageName, serviceName string,
) string {
	name := serviceName
	if p.getGenericServicesOptionsUsePkgName(file) {
		name = packageName + "." + name
	}
	return name
}

func init() {
	generator.RegisterPlugin(new(servicePlugin))
}
